{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "using JSON\n",
    "using Serialization\n",
    "using DelimitedFiles\n",
    "using Distances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "knnPrecision (generic function with 1 method)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "include(\"utils.jl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"/home/xiucheng/Github/t2vec/data\""
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Root data directory: the city's .h5 file, the region parameter file, the model\n",
    "# checkpoints and the per-experiment files below are all built from this path.\n",
    "# NOTE(review): machine-specific absolute path -- adjust before running elsewhere.\n",
    "datapath = \"/home/xiucheng/Github/t2vec/data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100.0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the shared hyper-parameter file and keep its \"region\" section at hand.\n",
    "param = JSON.parsefile(\"../hyper-parameters.json\")\n",
    "regionps = param[\"region\"]\n",
    "cityname, cellsize = regionps[\"cityname\"], regionps[\"cellsize\"]\n",
    "cellsize  # displayed as the cell's result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building spatial region with:\n",
      "        cityname=porto,\n",
      "        minlon=-8.735152,\n",
      "        minlat=40.953673,\n",
      "        maxlon=-8.156309,\n",
      "        maxlat=41.307945,\n",
      "        xstep=100.0,\n",
      "        ystep=100.0,\n",
      "        minfreq=100\n",
      "Reading parameter file from /home/xiucheng/Github/t2vec/data/porto-param-cell100\n",
      "Loaded /home/xiucheng/Github/t2vec/data/porto-param-cell100 into region\n"
     ]
    }
   ],
   "source": [
    "# Describe the grid laid over the city. After the bounding box come the cell\n",
    "# size (passed twice), minfreq, maxvocab_size, k and a final positional value\n",
    "# whose meaning is defined by SpatialRegion (not visible in this notebook).\n",
    "region = SpatialRegion(\n",
    "    cityname,\n",
    "    regionps[\"minlon\"], regionps[\"minlat\"],\n",
    "    regionps[\"maxlon\"], regionps[\"maxlat\"],\n",
    "    cellsize, cellsize,\n",
    "    regionps[\"minfreq\"],  # minfreq\n",
    "    40_000,               # maxvocab_size\n",
    "    10,                   # k\n",
    "    4)\n",
    "\n",
    "println(\"Building spatial region with:\n",
    "        cityname=$(region.name),\n",
    "        minlon=$(region.minlon),\n",
    "        minlat=$(region.minlat),\n",
    "        maxlon=$(region.maxlon),\n",
    "        maxlat=$(region.maxlat),\n",
    "        xstep=$(region.xstep),\n",
    "        ystep=$(region.ystep),\n",
    "        minfreq=$(region.minfreq)\")\n",
    "\n",
    "# Replace the freshly constructed region with the serialized one when its\n",
    "# parameter file exists; otherwise only report that the file is missing.\n",
    "paramfile = \"$datapath/$(region.name)-param-cell$(Int(cellsize))\"\n",
    "if !isfile(paramfile)\n",
    "    println(\"Cannot find $paramfile\")\n",
    "else\n",
    "    println(\"Reading parameter file from $paramfile\")\n",
    "    region = deserialize(paramfile)\n",
    "    println(\"Loaded $paramfile into region\")\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exp1: Similarity search without downsampling or distortion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"/home/xiucheng/Github/t2vec/data/exp1-trj.h5\""
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## Exp1 settings and the files derived from the prefix\n",
    "prefix = \"exp1\"\n",
    "do_split = true\n",
    "start = 1_000_000 + 20_000\n",
    "num_query, num_db = 1000, 100_000\n",
    "# Every experiment file lives directly under datapath and starts with the prefix.\n",
    "querydbfile, tfile, labelfile, vecfile = map(\n",
    "    name -> joinpath(datapath, name),\n",
    "    (\"$prefix-querydb.h5\", \"$prefix-trj.t\", \"$prefix-trj.label\", \"$prefix-trj.h5\"))\n",
    "vecfile  # displayed as the cell's result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "101000"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Exp1 applies no transformation: both function arguments of createQueryDB\n",
    "# hand their two inputs back unchanged.\n",
    "let passthrough = (x, y) -> (x, y)\n",
    "    createQueryDB(\"$datapath/$cityname.h5\", start, num_query, num_db,\n",
    "                  passthrough, passthrough;\n",
    "                  do_split=do_split, querydbfile=querydbfile)\n",
    "end\n",
    "# Derive the tfile / labelfile pair from the query-db file just written.\n",
    "createTLabel(region, querydbfile; tfile=tfile, labelfile=labelfile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "`\u001b[4mpython\u001b[24m \u001b[4mt2vec.py\u001b[24m \u001b[4m-mode\u001b[24m \u001b[4m2\u001b[24m \u001b[4m-vocab_size\u001b[24m \u001b[4m18864\u001b[24m \u001b[4m-checkpoint\u001b[24m \u001b[4m/home/xiucheng/Github/t2vec/data/best_model_gen.pt\u001b[24m \u001b[4m-prefix\u001b[24m \u001b[4mexp1\u001b[24m`\n",
      "Namespace(batch=128, bidirectional=True, bucketsize=[(20, 30), (30, 30), (30, 50), (50, 50), (50, 70), (70, 70), (70, 100), (100, 100)], checkpoint='/home/xiucheng/Github/t2vec/data/best_model_gen.pt', criterion_name='NLL', cuda=True, data='/home/xiucheng/Github/t2vec/data', discriminative_w=0.1, dist_decay_speed=0.8, dropout=0.2, embedding_size=256, epochs=15, generator_batch=32, hidden_size=256, knearestvocabs=None, learning_rate=0.001, max_grad_norm=5.0, max_length=200, max_num_line=20000000, mode=2, num_layers=3, prefix='exp1', pretrained_embedding=None, print_freq=50, save_freq=1000, start_iteration=0, t2vec_batch=256, use_discriminative=False, vocab_size=18864)\n",
      "=> loading checkpoint '/home/xiucheng/Github/t2vec/data/best_model_gen.pt'\n",
      "0: Encoding 256 trjs...\n",
      "100: Encoding 256 trjs...\n",
      "200: Encoding 256 trjs...\n",
      "300: Encoding 256 trjs...\n",
      "=> saving vectors into /home/xiucheng/Github/t2vec/data/exp1-trj.h5\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"/home/xiucheng/Github/t2vec/experiment\""
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the t2vec.py command (-mode 2) for the current prefix and checkpoint.\n",
    "# NOTE(review): the output stored with this cell was produced with\n",
    "# best_model_gen.pt, not best_model.pt -- confirm which checkpoint is intended.\n",
    "checkpoint = joinpath(datapath, \"best_model.pt\")\n",
    "t2vec = `python t2vec.py -mode 2 -vocab_size 18864 -checkpoint $checkpoint -prefix $prefix`\n",
    "println(t2vec)\n",
    "\n",
    "# t2vec.py is invoked by relative name, so run it from the directory below.\n",
    "# The do-block form of `cd` returns to the previous working directory\n",
    "# afterwards, including when `run` throws, so later relative paths in the\n",
    "# notebook keep resolving against the directory the kernel started in.\n",
    "cd(\"/home/xiucheng/Github/t2vec/\") do\n",
    "    run(t2vec)\n",
    "end\n",
    "pwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean rank: 2.214 with dbsize: 20000\n",
      "mean rank: 3.317 with dbsize: 40000\n",
      "mean rank: 4.532 with dbsize: 60000\n",
      "mean rank: 6.022 with dbsize: 80000\n",
      "mean rank: 7.224 with dbsize: 100000\n"
     ]
    }
   ],
   "source": [
    "## Load the encoded vectors (dataset \"layer3\") and the labels\n",
    "vecs = h5open(f -> read(f[\"layer3\"]), vecfile, \"r\")\n",
    "label = readdlm(labelfile, Int)\n",
    "\n",
    "# The first num_query columns / labels are the queries, the rest is the database.\n",
    "queryLabel = label[1:num_query]\n",
    "dbLabel = label[num_query+1:end]\n",
    "# ranksearch receives one vector per column, so split the matrix column by column.\n",
    "query = [vecs[:, col] for col in 1:num_query]\n",
    "db = [vecs[:, col] for col in num_query+1:size(vecs, 2)];\n",
    "\n",
    "# without discriminative loss\n",
    "dbsizes = collect(20_000:20_000:100_000)\n",
    "for dbsize in dbsizes\n",
    "    ranks = ranksearch(query, queryLabel, db[1:dbsize], dbLabel[1:dbsize], euclidean)\n",
    "    println(\"mean rank: $(mean(ranks)) with dbsize: $dbsize\")\n",
    "end\n",
    "# Reference numbers noted by the author (they differ from the stored output):\n",
    "# mean rank: 2.135 with dbsize: 20000\n",
    "# mean rank: 3.132 with dbsize: 40000\n",
    "# mean rank: 4.244 with dbsize: 60000\n",
    "# mean rank: 5.553 with dbsize: 80000\n",
    "# mean rank: 6.662 with dbsize: 100000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exp2: Similarity search with downsampling\n",
    "\n",
    "### create querydb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"/home/xiucheng/Github/t2vec/data/exp2-r6-trj.h5\""
      ]
     },
     "execution_count": 115,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Exp2 settings; the prefix carries ten times the rate (0.6 -> \"exp2-r6\").\n",
    "rate = 0.6\n",
    "prefix = \"exp2-r$(Int(10rate))\"\n",
    "do_split = true\n",
    "start = 1_000_000 + 20_000\n",
    "num_query, num_db = 1000, 100_000\n",
    "\n",
    "# Every experiment file lives directly under datapath and starts with the prefix.\n",
    "querydbfile, tfile, labelfile, vecfile = map(\n",
    "    name -> joinpath(datapath, name),\n",
    "    (\"$prefix-querydb.h5\", \"$prefix-trj.t\", \"$prefix-trj.label\", \"$prefix-trj.h5\"))\n",
    "vecfile  # displayed as the cell's result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "101000"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Disabled sweep over several rates, kept for reference:\n",
    "# for rate in [0.2, 0.3, 0.4, 0.5]\n",
    "#     querydbfile = joinpath(datapath, \"$prefix-r$(Int(10rate))-querydb.h5\")\n",
    "#     tfile = joinpath(datapath, \"$prefix-r$(Int(10rate))-trj.t\")\n",
    "#     labelfile = joinpath(datapath, \"$prefix-r$(Int(10rate))-trj.label\")\n",
    "#     vecfile = joinpath(datapath, \"$prefix-r$(Int(10rate))-trj.h5\")\n",
    "#     createQueryDB(\"$datapath/$cityname.h5\", start, num_query, num_db,\n",
    "#               (x, y)->downsampling(x, y, rate),\n",
    "#               (x, y)->downsampling(x, y, rate);\n",
    "#               do_split=do_split,\n",
    "#               querydbfile=querydbfile)\n",
    "#     createTLabel(region, querydbfile; tfile=tfile, labelfile=labelfile)\n",
    "# end\n",
    "\n",
    "# Build the Exp2 query-db file; both function arguments of createQueryDB are\n",
    "# the same closure, which applies downsampling with the current global rate.\n",
    "let downsample = (x, y) -> downsampling(x, y, rate)\n",
    "    createQueryDB(\"$datapath/$cityname.h5\", start, num_query, num_db,\n",
    "                  downsample, downsample;\n",
    "                  do_split=do_split, querydbfile=querydbfile)\n",
    "end\n",
    "# Derive the tfile / labelfile pair from the query-db file just written.\n",
    "createTLabel(region, querydbfile; tfile=tfile, labelfile=labelfile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "`\u001b[4mpython\u001b[24m \u001b[4mt2vec.py\u001b[24m \u001b[4m-mode\u001b[24m \u001b[4m2\u001b[24m \u001b[4m-vocab_size\u001b[24m \u001b[4m18864\u001b[24m \u001b[4m-checkpoint\u001b[24m \u001b[4m/home/xiucheng/Github/t2vec/data/best_model_gen.pt\u001b[24m \u001b[4m-prefix\u001b[24m \u001b[4mexp2-r6\u001b[24m`\n",
      "Namespace(batch=128, bidirectional=True, bucketsize=[(20, 30), (30, 30), (30, 50), (50, 50), (50, 70), (70, 70), (70, 100), (100, 100)], checkpoint='/home/xiucheng/Github/t2vec/data/best_model_gen.pt', criterion_name='NLL', cuda=True, data='/home/xiucheng/Github/t2vec/data', discriminative_w=0.1, dist_decay_speed=0.8, dropout=0.2, embedding_size=256, epochs=15, generator_batch=32, hidden_size=256, knearestvocabs=None, learning_rate=0.001, max_grad_norm=5.0, max_length=200, max_num_line=20000000, mode=2, num_layers=3, prefix='exp2-r6', pretrained_embedding=None, print_freq=50, save_freq=1000, start_iteration=0, t2vec_batch=256, use_discriminative=False, vocab_size=18864)\n",
      "=> loading checkpoint '/home/xiucheng/Github/t2vec/data/best_model_gen.pt'\n",
      "0: Encoding 256 trjs...\n",
      "100: Encoding 256 trjs...\n",
      "200: Encoding 256 trjs...\n",
      "300: Encoding 256 trjs...\n",
      "=> saving vectors into /home/xiucheng/Github/t2vec/data/exp2-r6-trj.h5\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"/home/xiucheng/Dropbox/code/t2vec+/experiment\""
      ]
     },
     "execution_count": 119,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the t2vec.py command (-mode 2) for the current prefix and checkpoint.\n",
    "checkpoint = joinpath(datapath, \"best_model_gen.pt\")\n",
    "t2vec = `python t2vec.py -mode 2 -vocab_size 18864 -checkpoint $checkpoint -prefix $prefix`\n",
    "println(t2vec)\n",
    "\n",
    "# t2vec.py is invoked by relative name, so run it from the directory below.\n",
    "# The do-block form of `cd` returns to the previous working directory\n",
    "# afterwards, including when `run` throws, so later relative paths in the\n",
    "# notebook keep resolving against the directory the kernel started in.\n",
    "cd(\"/home/xiucheng/Github/t2vec/\") do\n",
    "    run(t2vec)\n",
    "end\n",
    "pwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean rank: 16.451 with dbsize: 100000\n"
     ]
    }
   ],
   "source": [
    "# Load the encoded vectors (dataset \"layer3\") and the labels.\n",
    "vecs = h5open(f -> read(f[\"layer3\"]), vecfile, \"r\")\n",
    "label = readdlm(labelfile, Int)\n",
    "\n",
    "# The first num_query columns / labels are the queries, the rest is the database.\n",
    "queryLabel = label[1:num_query]\n",
    "dbLabel = label[num_query+1:end]\n",
    "# ranksearch receives one vector per column, so split the matrix column by column.\n",
    "query = [vecs[:, col] for col in 1:num_query]\n",
    "db = [vecs[:, col] for col in num_query+1:size(vecs, 2)];\n",
    "\n",
    "# without discriminative loss\n",
    "dbsize = 100_000\n",
    "ranks = ranksearch(query, queryLabel, db[1:dbsize], dbLabel[1:dbsize], euclidean)\n",
    "println(\"mean rank: $(mean(ranks)) with dbsize: $dbsize\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.3.0",
   "language": "julia",
   "name": "julia-1.3"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.3.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
